#!/usr/bin/env python

import rospy
import os
import numpy as np
import random
import time
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
from collections import defaultdict
from std_msgs.msg import Float32MultiArray
from src.turtlebot3_dqn.environment_stage_2 import Env

EPISODES = 3000


class QLearningAgent(object):
    """Tabular Q-learning agent with epsilon-greedy action selection.

    Q-values are kept in a ``defaultdict`` keyed by a hashable state (the
    caller in this file passes ``str(state)``); each entry is a list holding
    one Q-value per action, created as all zeros on first access.

    Attributes:
        action_size: number of discrete actions.
        learning_rate: step size applied to the temporal-difference error.
        discount_factor: weight of the best next-state Q-value in the target.
        epsilon: probability of choosing a uniformly random action.
        q_table: mapping ``state -> [Q(state, a) for each action a]``.
        q_value: Q-values looked at by the latest ``getAction`` call (zeros
            after an exploratory action, and before the first call).
        episode_step: upper bound on steps per episode read by the caller.
    """

    def __init__(self, action_size):
        """Create the result publisher and an empty Q-table.

        Args:
            action_size: number of discrete actions available to the agent.
        """
        self.pub_result = rospy.Publisher('result', Float32MultiArray, queue_size=5)
        self.result = Float32MultiArray()

        self.action_size = action_size
        self.learning_rate = 0.0025
        self.discount_factor = 0.95
        self.epsilon = 0.1
        # Size each row from action_size instead of a hard-coded five zeros so
        # the table stays consistent with randrange()/arg_max() indexing.
        self.q_table = defaultdict(lambda: [0.0] * action_size)
        # Defined up front so readers of agent.q_value never hit an
        # AttributeError when no action has been selected yet.
        self.q_value = np.zeros(action_size)
        self.episode_step = 6000

    def learn(self, state, action, reward, next_state):
        """Apply one Q-learning update for the observed transition.

        Args:
            state: hashable key of the state the action was taken in.
            action: index of the action that was taken.
            reward: reward received for the transition.
            next_state: hashable key of the resulting state.
        """
        current_q = self.q_table[state][action]
        # Bellman optimality target: reward plus discounted best next Q-value.
        target_q = reward + self.discount_factor * max(self.q_table[next_state])
        self.q_table[state][action] += self.learning_rate * (target_q - current_q)

    def getAction(self, state):
        """Pick an action for ``state`` using the epsilon-greedy rule.

        With probability ``epsilon`` a uniformly random action is returned and
        ``q_value`` is reset to zeros; otherwise the greedy action is returned
        (ties broken at random) and ``q_value`` holds that state's Q-values.

        Args:
            state: hashable key of the current state.

        Returns:
            Index of the selected action.
        """
        if np.random.rand() <= self.epsilon:
            self.q_value = np.zeros(self.action_size)
            return random.randrange(self.action_size)

        state_action = self.q_table[state]
        # Snapshot (copy) of the row so later learn() updates do not alter it.
        self.q_value = np.array(state_action, dtype=float)
        return self.arg_max(state_action)

    @staticmethod
    def arg_max(state_action):
        """Return the index of the largest value, choosing randomly among ties.

        Args:
            state_action: non-empty sequence of Q-values for one state.

        Returns:
            Index of a maximal element of ``state_action``.
        """
        max_value = max(state_action)
        max_index_list = [index for index, value in enumerate(state_action)
                          if value == max_value]
        return random.choice(max_index_list)


if __name__ == '__main__':
    # Training entry point: runs EPISODES episodes of tabular Q-learning
    # against the stage-2 environment, publishing per-step data on
    # 'get_action' and per-episode data on 'result'.
    rospy.init_node('turtlebot3_QLearning_stage_2')
    pub_result = rospy.Publisher('result', Float32MultiArray, queue_size=5)
    pub_get_action = rospy.Publisher('get_action', Float32MultiArray, queue_size=5)
    result = Float32MultiArray()
    get_action = Float32MultiArray()

    action_size = 5

    env = Env(action_size)

    agent = QLearningAgent(action_size)
    scores, episodes = [], []
    start_time = time.time()

    for e in range(EPISODES):
        done = False
        state = env.reset()
        score = 0
        for t in range(agent.episode_step):
            # States are turned into strings so they can key the Q-table.
            action = agent.getAction(str(state))

            next_state, reward, done = env.step(action)

            score += reward
            agent.learn(str(state), action, reward, str(next_state))
            state = next_state
            get_action.data = [action, score, reward]
            pub_get_action.publish(get_action)

            # Hard cap: the episode is forced to end once t exceeds 500.
            if t > 500:
                rospy.loginfo("Time out.")
                done = True

            if done:
                # Report the best Q-value of the latest state straight from the
                # Q-table. agent.q_value is only assigned after an exploratory
                # action (and then holds zeros), so reading it here could raise
                # AttributeError or always publish 0.
                max_q = float(max(agent.q_table[str(state)]))
                result.data = [score, max_q]
                pub_result.publish(result)
                scores.append(score)
                episodes.append(e)
                m, s = divmod(int(time.time() - start_time), 60)
                h, m = divmod(m, 60)

                rospy.loginfo('Ep: %d score: %.2f epsilon: %.2f time: %d:%02d:%02d',
                              e, score, agent.epsilon, h, m, s)
                break


